Skip to content

Commit 56357ef

Browse files
committed
support for image attachments when classifying questions
1 parent 42cc62f commit 56357ef

File tree

2 files changed

+43
-5
lines changed

2 files changed

+43
-5
lines changed

kitsune/llm/questions/classifiers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def classify_question(question: "Question") -> dict[str, Any]:
3232
payload: dict[str, Any] = {
3333
"subject": question.title,
3434
"question": question.content,
35+
"image_urls": [image.get_absolute_url() for image in question.get_images()],
3536
"product": product,
3637
"topics": get_taxonomy(
3738
product, include_metadata=["description", "examples"], output_format="JSON"

kitsune/llm/questions/prompt.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
2-
from langchain.prompts import ChatPromptTemplate
2+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
3+
from langchain.schema import HumanMessage
4+
from langchain.schema.runnable import RunnableLambda
35

46
SPAM_INSTRUCTIONS = """
57
# Role and goal
@@ -120,17 +122,52 @@
120122
)
121123

122124

123-
spam_prompt = ChatPromptTemplate(
125+
spam_prompt_template = ChatPromptTemplate(
124126
(
125127
("system", SPAM_INSTRUCTIONS),
126-
("human", USER_QUESTION),
128+
MessagesPlaceholder("human_message"),
127129
)
128130
).partial(format_instructions=spam_parser.get_format_instructions())
129131

130132

131-
topic_prompt = ChatPromptTemplate(
133+
topic_prompt_template = ChatPromptTemplate(
132134
(
133135
("system", TOPIC_INSTRUCTIONS),
134-
("human", USER_QUESTION),
136+
MessagesPlaceholder("human_message"),
135137
)
136138
).partial(format_instructions=topic_parser.get_format_instructions())
139+
140+
141+
def create_human_message(inputs: dict) -> dict:
142+
"""
143+
Creates the human message, with the image URL's if they're present, and
144+
then adds it to the inputs dict. Returns the modified inputs dict.
145+
"""
146+
image_urls = inputs.pop("image_urls", None)
147+
148+
content: list[dict] = [
149+
{
150+
"type": "text",
151+
"text": USER_QUESTION.format(**inputs),
152+
},
153+
]
154+
155+
if image_urls:
156+
for image_url in image_urls:
157+
content.append(
158+
{
159+
"type": "image_url",
160+
"image_url": {
161+
"url": image_url,
162+
},
163+
}
164+
)
165+
166+
inputs["human_message"] = [HumanMessage(content=content)]
167+
return inputs
168+
169+
170+
spam_prompt = RunnableLambda(create_human_message) | spam_prompt_template
171+
172+
173+
topic_prompt = RunnableLambda(create_human_message) | topic_prompt_template

0 commit comments

Comments
 (0)